{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pynvml import *\n",
    "\n",
    "nvmlInit()\n",
    "vram = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(0)).free/1024.**2\n",
    "print('GPU0 Memory: %dMB' % vram)\n",
    "if vram < 8000:\n",
    "    raise Exception('GPU Memory too low')\n",
    "nvmlShutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import h5py\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import *\n",
    "from collections import Counter\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import re\n",
    "import time\n",
    "import random\n",
    "\n",
    "from keras.layers import *\n",
    "from keras.models import *\n",
    "from keras.optimizers import *\n",
    "from keras.regularizers import l2\n",
    "from keras.utils.vis_utils import model_to_dot\n",
    "import keras.backend as K\n",
    "from make_parallel import make_parallel\n",
    "from keras.backend.tensorflow_backend import set_session\n",
    "\n",
    "import tensorflow as tf\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth=True\n",
    "set_session(tf.Session(config=config))\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "IMAGE_DIR = 'image_contest_level_2'\n",
    "\n",
    "DESCRIPTION = 'l2加层生成器2'\n",
    "MODEL_NAME = 'model_%s_best.h5' % DESCRIPTION"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 载入基本数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "df = pd.read_csv('../../image_contest_level_2/labels.txt', sep=' ', header=None)\n",
    "characters = u'0123456789()+-*/=君不见黄河之水天上来奔流到海复回烟锁池塘柳深圳铁板烧; '\n",
    "\n",
    "labels_len = np.array(map(lambda x:len(x.decode('utf-8')), df[0]))\n",
    "n_len = 45\n",
    "n, width, height, n_class, channels = 100000, 900, 72, len(characters), 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def decode(out):\n",
    "    return ''.join([characters[x] for x in out if x < n_class-1 and x > -1])\n",
    "\n",
    "def disp2(img):\n",
    "    cv2.imwrite('a.png', img)\n",
    "    return display(Image('a.png'))\n",
    "\n",
    "def disp(img, txt=None, first=False):\n",
    "    global index\n",
    "    if first:\n",
    "        index = 1\n",
    "        plt.figure(figsize=(16, 9))\n",
    "    else:\n",
    "        index += 1\n",
    "    plt.subplot(4, 1, index)\n",
    "    if len(img.shape) == 2:\n",
    "        plt.imshow(img, cmap='gray')\n",
    "    else:\n",
    "        plt.imshow(img[:,:,::-1])\n",
    "    if txt:\n",
    "        plt.title(txt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ctc_lambda_func(args):\n",
    "    y_pred, labels, input_length, label_length = args\n",
    "    y_pred = y_pred[:, 2:, :]\n",
    "    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)\n",
    "\n",
    "rnn_size = 128\n",
    "\n",
    "l2_rate = 1e-5\n",
    "\n",
    "input_tensor = Input((width, height, 3))\n",
    "x = input_tensor\n",
    "for i, n_cnn in enumerate([2, 3, 6]):\n",
    "    for j in range(n_cnn):\n",
    "        x = Conv2D(32*2**i, (3, 3), padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(l2_rate))(x)\n",
    "        x = BatchNormalization()(x)\n",
    "        x = Activation('relu')(x)\n",
    "    x = MaxPooling2D((2, 2))(x)\n",
    "\n",
    "cnn_model = Model(input_tensor, x, name='cnn')\n",
    "\n",
    "input_tensor = Input((width, height, 3))\n",
    "x = cnn_model(input_tensor)\n",
    "\n",
    "conv_shape = x.get_shape().as_list()\n",
    "rnn_length = conv_shape[1]\n",
    "rnn_dimen = conv_shape[3]*conv_shape[2]\n",
    "\n",
    "print conv_shape, rnn_length, rnn_dimen\n",
    "\n",
    "x = Reshape(target_shape=(rnn_length, rnn_dimen))(x)\n",
    "rnn_length -= 2\n",
    "\n",
    "x = Dense(rnn_size, kernel_initializer='he_normal', kernel_regularizer=l2(l2_rate))(x)\n",
    "x = BatchNormalization()(x)\n",
    "x = Activation('relu')(x)\n",
    "x = Dropout(0.25)(x)\n",
    "\n",
    "gru_1 = GRU(rnn_size, implementation=2, return_sequences=True, name='gru1')(x)\n",
    "gru_1b = GRU(rnn_size, implementation=2, return_sequences=True, go_backwards=True, name='gru1_b')(x)\n",
    "gru1_merged = add([gru_1, gru_1b])\n",
    "\n",
    "gru_2 = GRU(rnn_size, implementation=2, return_sequences=True, name='gru2')(gru1_merged)\n",
    "gru_2b = GRU(rnn_size, implementation=2, return_sequences=True, go_backwards=True, name='gru2_b')(gru1_merged)\n",
    "x = concatenate([gru_2, gru_2b])\n",
    "x = Dropout(0.25)(x)\n",
    "x = Dense(n_class, activation='softmax', kernel_regularizer=l2(l2_rate))(x)\n",
    "rnn_out = x\n",
    "base_model = Model(input_tensor, x)\n",
    "#base_model2 = make_parallel(base_model, 4)\n",
    "\n",
    "\n",
    "labels = Input(name='the_labels', shape=[n_len], dtype='float32')\n",
    "input_length = Input(name='input_length', shape=(1,), dtype='int64')\n",
    "label_length = Input(name='label_length', shape=(1,), dtype='int64')\n",
    "loss_out = Lambda(ctc_lambda_func, name='ctc')([base_model.output, labels, input_length, label_length])\n",
    "#loss_out = Lambda(ctc_lambda_func, name='ctc')([base_model2.output, labels, input_length, label_length])\n",
    "\n",
    "model = Model(inputs=(input_tensor, labels, input_length, label_length), outputs=loss_out)\n",
    "model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam')\n",
    "model.save('test.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "SVG(model_to_dot(base_model, show_shapes=True).create(prog='dot', format='svg'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 载入训练集做验证集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n1 = 350*256\n",
    "n2 = 390*256\n",
    "\n",
    "X = np.zeros((n2-n1, width, height, channels), dtype=np.uint8)\n",
    "y = np.zeros((n2-n1, n_len), dtype=np.uint8)\n",
    "\n",
    "for i in tqdm(range(n1, n2)):\n",
    "    img = cv2.imread('crop_split/%d.png'%i).transpose(1, 0, 2)\n",
    "    a, b, _ = img.shape\n",
    "    X[i-n1, :a, :b] = img\n",
    "    \n",
    "    label = df[0][i].decode('utf-8')\n",
    "    y[i-n1,:len(label)] = [characters.find(x) for x in label]\n",
    "    y[i-n1,len(label):] = n_class-1\n",
    "\n",
    "X_val = [X, y, np.ones(n2-n1)*rnn_length, labels_len[n1:n2]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 生成器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "cn_imgs = defaultdict(list)\n",
    "cn_labels = defaultdict(list)\n",
    "ss_imgs = []\n",
    "ss_labels = []\n",
    "\n",
    "for i in tqdm(range(n1)):\n",
    "    ss = df[0][i].decode('utf-8').split(';')\n",
    "    m = len(ss)-1\n",
    "    ss_labels.append(ss[-1])\n",
    "    ss_imgs.append(cv2.imread('crop_split/%d_%d.png'%(i, 0)).transpose(1, 0, 2))\n",
    "    for j in range(m):\n",
    "        cn_labels[ss[j][0]].append(ss[j])\n",
    "        cn_imgs[ss[j][0]].append(cv2.imread('crop_split/%d_%d.png'%(i, m-j)).transpose(1, 0, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.utils import Sequence\n",
    "\n",
    "class SGen(Sequence):\n",
    "    def __init__(self, batch_size):\n",
    "        self.batch_size = batch_size\n",
    "        self.X_gen = np.zeros((batch_size, width, height, 3), dtype=np.uint8)\n",
    "        self.y_gen = np.zeros((batch_size, n_len), dtype=np.uint8)\n",
    "        self.input_length = np.ones(batch_size)*rnn_length\n",
    "        self.label_length = np.ones(batch_size)*38\n",
    "    \n",
    "    def __len__(self):\n",
    "        return 350*256 // self.batch_size\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        random.seed(time.time()*100000+idx)\n",
    "        self.X_gen[:] = 0\n",
    "        for i in range(self.batch_size):\n",
    "            random_index = random.randint(0, n1-1)\n",
    "            cls = []\n",
    "            ss = ss_labels[random_index]\n",
    "            cs = re.findall(ur'[\\u4e00-\\u9fff]', df[0][random_index].decode('utf-8').split(';')[-1])\n",
    "            random.shuffle(cs)\n",
    "            x = 0\n",
    "            for c in cs:\n",
    "                random_index2 = random.randint(0, len(cn_labels[c])-1)\n",
    "                cls.append(cn_labels[c][random_index2])\n",
    "                img = cn_imgs[c][random_index2]\n",
    "                w, h, _ = img.shape\n",
    "                self.X_gen[i, x:x+w, :h] = img\n",
    "                x += w+2\n",
    "            img = ss_imgs[random_index]\n",
    "            w, h, _ = img.shape\n",
    "            self.X_gen[i, x:x+w, :h] = img\n",
    "            cls.append(ss)\n",
    "\n",
    "            random_str = u';'.join(cls)\n",
    "            self.y_gen[i,:len(random_str)] = [characters.find(x) for x in random_str]\n",
    "            self.y_gen[i,len(random_str):] = n_class-1\n",
    "            self.label_length[i] = len(random_str)\n",
    "        \n",
    "        return [self.X_gen, self.y_gen, self.input_length, self.label_length], np.ones(self.batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 可视化生成器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = SGen(1)\n",
    "(X_vis, y_vis, input_length_vis, label_length_vis), _ = gen[0]\n",
    "print input_length_vis[0]\n",
    "print label_length_vis[0]\n",
    "print decode(y_vis[0])\n",
    "disp2(X_vis[0].transpose((1, 0, 2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def evaluate():\n",
    "    global out\n",
    "    #y_pred = base_model2.predict(X, batch_size=256, verbose=1)\n",
    "    y_pred = base_model.predict(X, batch_size=256, verbose=1)\n",
    "    out = K.get_value(K.ctc_decode(y_pred[:,2:], input_length=np.ones(y_pred.shape[0])*rnn_length)[0][0])[:, :n_len]\n",
    "    out[out == -1] = n_class-1\n",
    "    score = (out == y[:,:out.shape[-1]]).all(axis=-1).mean()\n",
    "    base_model.save('model_%s_%.6f.h5' % (DESCRIPTION, score))\n",
    "    return score\n",
    "\n",
    "def train(lr=1e-3, epochs=10, batch_size=128, callbacks=None):\n",
    "    opt = Adam(lr)\n",
    "    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=opt)\n",
    "    model.fit_generator(SGen(batch_size), steps_per_epoch=350*256/batch_size, \n",
    "                        workers=6, use_multiprocessing=True,\n",
    "                        validation_data=(X_val, np.ones(n2-n1)), epochs=epochs, callbacks=callbacks)\n",
    "    print evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 采用每次训练最好的模型\n",
    "\n",
    "http://keras-cn.readthedocs.io/en/latest/other/callbacks/#modelcheckpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.callbacks import *\n",
    "\n",
    "mc = ModelCheckpoint(MODEL_NAME, save_best_only=True)\n",
    "cl = CSVLogger('model_%s.csv' % DESCRIPTION, append=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "train(epochs=20, callbacks=[mc, cl], batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.load_weights(MODEL_NAME)\n",
    "evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train(1e-4, callbacks=[mc, cl])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.load_weights(MODEL_NAME)\n",
    "evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "train(1e-5, callbacks=[mc, cl])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.load_weights(MODEL_NAME)\n",
    "evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train(1e-4, callbacks=[mc, cl])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.load_weights(MODEL_NAME)\n",
    "evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train(1e-5, callbacks=[mc, cl])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.load_weights(MODEL_NAME)\n",
    "evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "base_model.save('model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
